{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "## Lexicase Selection注意事项\n",
    "\n",
    "对于Lexicase Selection，适应度评估函数需要返回由每个测试样本的误差组成的向量，而不是聚合后的单一均方误差（MSE）。这样，Lexicase Selection才能独立考虑每个个体在每个测试样本上的表现，从而提高选择的多样性。相应地，适应度的weights长度必须与测试样本的数量一致。"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "ff6050dfa4dc1b6"
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "import math\n",
    "import operator\n",
    "from functools import partial\n",
    "\n",
    "import numpy as np\n",
    "from deap import base, creator, tools, gp\n",
    "\n",
    "\n",
    "# Symbolic regression against the target f(x) = x**2.\n",
    "# Number of sample points. Each point is a separate lexicase test case, so the\n",
    "# length of the fitness weights tuple below must equal this value.\n",
    "N_SAMPLES = 100\n",
    "\n",
    "\n",
    "def evalSymbReg(individual, pset):\n",
    "    \"\"\"Return the squared error of a GP tree at every sample point.\n",
    "\n",
    "    Args:\n",
    "        individual: GP expression tree to evaluate.\n",
    "        pset: Primitive set used to compile the tree.\n",
    "\n",
    "    Returns:\n",
    "        Tuple of N_SAMPLES squared errors against x**2, one per sample\n",
    "        point, instead of a single aggregated MSE.\n",
    "    \"\"\"\n",
    "    # Compile the GP tree into a callable\n",
    "    func = gp.compile(expr=individual, pset=pset)\n",
    "\n",
    "    # Evenly spaced sample points on [-10, 10]\n",
    "    x = np.linspace(-10, 10, N_SAMPLES)\n",
    "\n",
    "    return tuple((func(x) - x**2)**2)\n",
    "\n",
    "\n",
    "# Individual and fitness classes: one minimised objective per sample point\n",
    "creator.create(\"FitnessMin\", base.Fitness, weights=(-1.0,) * N_SAMPLES)\n",
    "creator.create(\"Individual\", gp.PrimitiveTree, fitness=creator.FitnessMin)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-07T09:06:58.369619300Z",
     "start_time": "2023-11-07T09:06:58.365066500Z"
    }
   },
   "id": "59cfefc0467c74ad"
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 遗传算子\n",
    "选择算子需要改成Lexicase Selection，其他算子不需要改变。对于回归问题，需要使用Automatic Epsilon Lexicase（DEAP中的`tools.selAutomaticEpsilonLexicase`）；而对于分类问题，使用标准的Lexicase（`tools.selLexicase`）即可。"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "956e01e17271daa6"
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# Function set and terminal set (single input argument, renamed to x below)\n",
    "pset = gp.PrimitiveSet(\"MAIN\", arity=1)\n",
    "pset.addPrimitive(operator.add, 2)\n",
    "pset.addPrimitive(operator.sub, 2)\n",
    "pset.addPrimitive(operator.mul, 2)\n",
    "pset.addPrimitive(operator.neg, 1)\n",
    "# Ephemeral constant drawn from {-1, 0, 1}. A functools.partial is used instead\n",
    "# of a lambda so the generator stays picklable; recent DEAP releases emit a\n",
    "# RuntimeWarning for lambda generators.\n",
    "pset.addEphemeralConstant(\"rand101\", partial(random.randint, -1, 1))\n",
    "pset.renameArguments(ARG0='x')\n",
    "\n",
    "# Genetic programming operators\n",
    "toolbox = base.Toolbox()\n",
    "toolbox.register(\"expr\", gp.genHalfAndHalf, pset=pset, min_=1, max_=2)\n",
    "toolbox.register(\"individual\", tools.initIterate, creator.Individual, toolbox.expr)\n",
    "toolbox.register(\"population\", tools.initRepeat, list, toolbox.individual)\n",
    "toolbox.register(\"compile\", gp.compile, pset=pset)\n",
    "toolbox.register(\"evaluate\", evalSymbReg, pset=pset)\n",
    "# Only the selection operator differs from a standard GP setup\n",
    "toolbox.register(\"select\", tools.selAutomaticEpsilonLexicase)\n",
    "toolbox.register(\"mate\", gp.cxOnePoint)\n",
    "toolbox.register(\"mutate\", gp.mutUniform, expr=toolbox.expr, pset=pset)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-07T09:06:58.378447200Z",
     "start_time": "2023-11-07T09:06:58.370620700Z"
    }
   },
   "id": "851794d4d36e3681"
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   \t      \t                    fitness                    \t                      size                     \n",
      "   \t      \t-----------------------------------------------\t-----------------------------------------------\n",
      "gen\tnevals\tavg    \tgen\tmax  \tmin\tnevals\tstd    \tavg \tgen\tmax\tmin\tnevals\tstd    \n",
      "0  \t20    \t1976.63\t0  \t12544\t0  \t20    \t2791.12\t4.65\t0  \t7  \t3  \t20    \t1.58981\n",
      "1  \t12    \t313.902\t1  \t12100\t0  \t12    \t1323.29\t3.5 \t1  \t7  \t3  \t12    \t1.07238\n",
      "2  \t11    \t1045.82\t2  \t40000\t0  \t11    \t3494.81\t3.9 \t2  \t9  \t3  \t11    \t1.64012\n",
      "3  \t10    \t7999.5 \t3  \t1.21e+06\t0  \t10    \t67582.2\t3.65\t3  \t8  \t3  \t10    \t1.4239 \n",
      "4  \t10    \t7897.13\t4  \t1.1881e+06\t0  \t10    \t66124.7\t3.5 \t4  \t8  \t3  \t10    \t1.20416\n",
      "5  \t12    \t211.535\t5  \t12100     \t0  \t12    \t1097.11\t3.6 \t5  \t9  \t2  \t12    \t1.71464\n",
      "6  \t11    \t104.067\t6  \t10000     \t0  \t11    \t768.451\t3.2 \t6  \t5  \t3  \t11    \t0.6    \n",
      "7  \t11    \t31682.4\t7  \t5.29e+06  \t0  \t11    \t281300 \t3.75\t7  \t9  \t3  \t11    \t1.51245\n",
      "8  \t13    \t418.07 \t8  \t12100     \t0  \t13    \t1509.12\t3.8 \t8  \t7  \t3  \t13    \t1.28841\n",
      "9  \t5     \t423.121\t9  \t12321     \t0  \t5     \t1532.89\t3.95\t9  \t9  \t3  \t5     \t1.59609\n",
      "10 \t11    \t8212.78\t10 \t1.21e+06  \t0  \t11    \t67566  \t3.7 \t10 \t9  \t3  \t11    \t1.45258\n",
      "11 \t16    \t1042.37\t11 \t44100     \t0  \t16    \t4420.55\t3.95\t11 \t9  \t2  \t16    \t1.88348\n",
      "12 \t12    \t39168.4\t12 \t4.84e+06  \t0  \t12    \t277812 \t4.8 \t12 \t11 \t3  \t12    \t2.7313 \n",
      "13 \t13    \t7902.23\t13 \t1.21e+06  \t0  \t13    \t67590.2\t3.9 \t13 \t12 \t3  \t13    \t2.16564\n",
      "14 \t12    \t416.269\t14 \t40000     \t0  \t12    \t3073.8 \t3.1 \t14 \t5  \t3  \t12    \t0.43589\n",
      "15 \t9     \t1.70034\t15 \t100       \t0  \t9     \t10.0586\t3.1 \t15 \t5  \t3  \t9     \t0.43589\n",
      "16 \t15    \t312.252\t16 \t12100     \t0  \t15    \t1329.2 \t3.2 \t16 \t5  \t2  \t15    \t0.812404\n",
      "17 \t11    \t105.768\t17 \t10000     \t0  \t11    \t768.286\t3.9 \t17 \t9  \t3  \t11    \t1.60935 \n",
      "18 \t13    \t941.757\t18 \t40000     \t0  \t13    \t3437.72\t4.5 \t18 \t9  \t3  \t13    \t1.93649 \n",
      "19 \t17    \t8310.05\t19 \t1.44e+06  \t0  \t17    \t73988.5\t4.05\t19 \t7  \t3  \t17    \t1.28355 \n",
      "20 \t14    \t419.67 \t20 \t12100     \t0  \t14    \t1508.3 \t4.45\t20 \t7  \t3  \t14    \t1.46544 \n",
      "mul(x, x)\n"
     ]
    }
   ],
   "source": [
    "import numpy\n",
    "from deap import algorithms\n",
    "\n",
    "# Seed the stdlib RNG (used by the ephemeral constant and, presumably, by the\n",
    "# DEAP operators registered above) so that Restart & Run All reproduces the run.\n",
    "random.seed(42)\n",
    "\n",
    "# Statistics over the per-case error values and over tree sizes\n",
    "stats_fit = tools.Statistics(lambda ind: ind.fitness.values)\n",
    "stats_size = tools.Statistics(len)\n",
    "mstats = tools.MultiStatistics(fitness=stats_fit, size=stats_size)\n",
    "mstats.register(\"avg\", numpy.mean)\n",
    "mstats.register(\"std\", numpy.std)\n",
    "mstats.register(\"min\", numpy.min)\n",
    "mstats.register(\"max\", numpy.max)\n",
    "\n",
    "# Run DEAP's stock evolutionary loop\n",
    "population = toolbox.population(n=20)\n",
    "# NOTE(review): HallOfFame ranks individuals by Fitness comparison; with a\n",
    "# multi-value fitness this is presumably lexicographic over the cases rather\n",
    "# than by total error - confirm this is the intended notion of \"best\".\n",
    "hof = tools.HallOfFame(1)\n",
    "pop, log = algorithms.eaSimple(\n",
    "    population=population,\n",
    "    toolbox=toolbox,\n",
    "    cxpb=0.5,\n",
    "    mutpb=0.2,\n",
    "    ngen=20,\n",
    "    stats=mstats,\n",
    "    halloffame=hof,\n",
    "    verbose=True,\n",
    ")\n",
    "print(hof[0])\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-07T09:07:09.006767300Z",
     "start_time": "2023-11-07T09:06:58.377448600Z"
    }
   },
   "id": "515b587d4f8876ea"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
